Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JAX Support to NPBench and Implement JAX Benchmarks #31

Open
wants to merge 114 commits into
base: main
Choose a base branch
from

Conversation

hardik01shah
Copy link

Overview:

This PR introduces JAX as a supported framework in NPBench, providing JAX implementations for all existing benchmarks. This enhancement allows NPBench to compare the performance of JAX against other frameworks, for all implemented benchmarks in NPBench.


Changes Introduced:

  1. Framework Configuration:
    • Added framework_info/jax.json to define framework-specific configuration settings for JAX.
  2. Benchmark Implementations:
    • Added JAX implementations for all benchmarks in the directory structure:
      • npbench/benchmarks/<bench_parent>/<bench_name>_jax.py.
    • Support for benchmarks with existing JAX library implementations.
      • npbench/benchmarks/<bench_parent>/<bench_name>_jax_lib.py
  3. Infrastructure Updates:
    • Updated npbench/infrastructure/__init__.py to include JAX framework initialization.
    • Added npbench/infrastructure/jax_framework.py to handle:
      • Pre- and post-processing for JAX array arguments (e.g., using block_until_ready() and jnp.array()).
  4. Bug Fix:
    • Resolved a minor bug in jax_framework.py related to incorrect naming of the current framework during benchmarking.

Motivation:

The addition of JAX support enhances NPBench by:

  • Allowing direct comparison of JAX’s performance with other frameworks across diverse scientific Python benchmarks.
  • Extending the utility of NPBench to the JAX community for performance evaluation and optimization.

Testing and Validation:

  • Verified the correctness of JAX implementations by validating the implementations with the Numpy impplementations.
  • Ensured smooth integration of JAX into the benchmarking pipeline.

Contributors:

hardik01shah and others added 30 commits October 29, 2024 10:24
Define the JaxFramework class for implementing block_until_ready() calls after kernel computation for correctness of profiling, convert np.ndarray to jax Array in the copy_func()
Define the JaxFramework class for implementing block_until_ready() calls after kernel computation for correctness of profiling, convert np.ndarray to jax Array in the copy_func()
sushant1212 and others added 30 commits November 23, 2024 14:56
Previously >10x slower, now around same runtime as numpy.
Previously was up to 70x slower, now it's faster for smaller sizes and comparable to numpy for bigger ones.
Previously was up to 90x slower, now it's up to 40x faster than numpy.
Previously was up to 3x slower, now comparable to numpy.
Add JAX as a supported framework
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants